package cn.dg.manager.common.interceptor;

import java.io.StringReader;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

import cn.dg.manager.common.entity.PageEntity;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.JSqlParser;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;

/**
 * 
 * 分页拦截器，用于拦截需要进行分页查询的操作，然后对其进行分页处理。 利用拦截器实现Mybatis分页的原理：
 * 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象，Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象，而且对应的Sql语句
 * 是在Statement之前产生的，所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的
 * prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法，然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句，之后再调用
 * StatementHandler对象的prepare方法，即调用invocation.proceed()。
 * 对于分页而言，在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少，这是通过获取到了原始的Sql语句后，把它改为对应的统计语句再利用Mybatis封装好的参数和设
 * 置参数的功能把Sql语句中的参数进行替换，之后再执行查询记录数的Sql语句进行总记录数的统计。
 * 
 */
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class})})
public class PageInterceptor implements Interceptor
{
    /**
     * 数据库类型，不同的数据库有不同的分页方法
     */
    private String databaseType;
    
    // mapper.xml中需要拦截的ID(正则匹配)
    private String pageSqlId; 
    
    /**
     * 拦截后要执行的方法
     */
    public Object intercept(Invocation invocation)
        throws Throwable
    {
        // 对于StatementHandler其实只有两个实现类，一个是RoutingStatementHandler，另一个是抽象类BaseStatementHandler，
        // BaseStatementHandler有三个子类，分别是SimpleStatementHandler，PreparedStatementHandler和CallableStatementHandler，
        // SimpleStatementHandler是用于处理Statement的，PreparedStatementHandler是处理PreparedStatement的，而CallableStatementHandler是
        // 处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler，而在RoutingStatementHandler里面拥有一个
        // StatementHandler类型的delegate属性，RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler，即SimpleStatementHandler、
        // PreparedStatementHandler或CallableStatementHandler，在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。
        // 我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法，又因为Mybatis只有在建立RoutingStatementHandler的时候
        // 是通过Interceptor的plugin方法进行包裹的，所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。
    	
    	if (invocation.getTarget() instanceof RoutingStatementHandler) {
	        RoutingStatementHandler handler = (RoutingStatementHandler)invocation.getTarget();
	        // 通过反射获取到当前RoutingStatementHandler对象的delegate属性
	        StatementHandler delegate = (StatementHandler)ReflectUtil.getFieldValue(handler, "delegate");
	        // 通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
            MappedStatement mappedStatement = (MappedStatement)ReflectUtil.getFieldValue(delegate, "mappedStatement");
			// System.err.println("============================================================");
			// System.err.println(mappedStatement.getId());
			// System.err.println(pageSqlId);
			// System.err.println(mappedStatement.getId().matches(pageSqlId));
			// System.err.println("============================================================");
			if (mappedStatement.getId().matches(pageSqlId)) {
				// 获取到当前StatementHandler的
				// boundSql，这里不管是调用handler.getBoundSql()还是直接调用delegate.getBoundSql()结果是一样的，因为之前已经说过了
				// RoutingStatementHandler实现的所有StatementHandler接口方法里面都是调用的delegate对应的方法。
				BoundSql boundSql = delegate.getBoundSql();

				// 拿到当前绑定Sql的参数对象，就是我们在调用对应的Mapper映射语句时所传入的参数对象
				Object obj = boundSql.getParameterObject();

				// 这里我们简单的通过传入的是Page对象就认定它是需要进行分页操作的。
				if (obj instanceof PageEntity<?>) {
					PageEntity<?> page = (PageEntity<?>) obj;

					// 拦截到的prepare方法参数是一个Connection对象
					Connection connection = (Connection) invocation.getArgs()[0];

					// 获取当前要执行的Sql语句，也就是我们直接在Mapper映射语句中写的Sql语句
					String sql = boundSql.getSql();

					DataAuthoritionSql authoritionSql;
					authoritionSql = new DataAuthoritionSql();
					authoritionSql.setSql(sql);

					// 给当前的page参数对象设置总记录数
					this.setTotalRecord(authoritionSql, page, mappedStatement, connection);

					// 获取分页Sql语句
					String pageSql = this.getPageSql(page, authoritionSql.getSql());

					// 利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
					ReflectUtil.setFieldValue(boundSql, "sql", pageSql);

					// 是否为空记录SQL
					if (authoritionSql.isNull()) {
						// 利用反射设置当前BoundSql对应的parameterMappings属性为NULL
						ReflectUtil.setFieldValue(boundSql, "parameterMappings", null);

						// 利用反射设置当前BoundSql对应的parameterObject属性为NULL
						ReflectUtil.setFieldValue(boundSql, "parameterObject", null);
					}
				}
			}
    	}
        return invocation.proceed();
    }
    
    /**
     * 拦截器对应的封装原始对象的方法
     */
    public Object plugin(Object target)
    {
        return Plugin.wrap(target, this);
    }
    
    /**
     * 设置注册拦截器时设定的属性
     */
    public void setProperties(Properties properties)
    {
        this.databaseType = properties.getProperty("databaseType");
        
        this.pageSqlId = properties.getProperty("pageSqlId");
    }
    
    /**
     * 解析SQL，获取所有的查询字段
     * 
     * @param sql
     * @throws JSQLParserException
     * @see [类、类#方法、类#成员]
     */
    private List<String> parseSqlSelectColumns(String sql)
        throws JSQLParserException
    {
        if (StringUtils.isEmpty(sql))
        {
            return null;
        }
        
        List<String> selectColumns = new ArrayList<String>();
        JSqlParser parser = new CCJSqlParserManager();
        Statement statement = parser.parse(new StringReader(sql));
        
        if (statement instanceof Select)
        {
            Select select = (Select)statement;
            PlainSelect plain = (PlainSelect)select.getSelectBody();
            List<SelectItem> selectItems = plain.getSelectItems();
            
            if (CollectionUtils.isEmpty(selectItems))
            {
                return null;
            }
            
            // 字段别名
            String alias;
            Expression expression;
            for (SelectItem selectItem : selectItems)
            {
                if (selectItem instanceof SelectExpressionItem)
                {
                    alias = ((SelectExpressionItem)selectItem).getAlias();
                    
                    if (!StringUtils.isEmpty(alias))
                    {
                        selectColumns.add(alias);
                    }
                    else
                    {
                        expression = ((SelectExpressionItem)selectItem).getExpression();
                        if (expression instanceof Column)
                        {
                            selectColumns.add(((Column)expression).getColumnName());
                        }
                    }
                }
            }
        }
        
        return selectColumns;
    }
    
    /**
     * 构建空记录SQL语句
     * 
     * @param selectColumns
     * @return
     * @see [类、类#方法、类#成员]
     */
    private String getNullSql(List<String> selectColumns)
    {
        if (CollectionUtils.isEmpty(selectColumns))
        {
            return "SELECT 1 FROM DUAL WHERE 0=1";
        }
        
        StringBuffer buffer = new StringBuffer();
        buffer.append("SELECT");
        for (String column : selectColumns)
        {
            buffer.append(" '' AS ").append(column).append(",");
        }
        
        buffer.deleteCharAt(buffer.length() - 1);
        buffer.append(" FROM DUAL WHERE 0=1");
        
        return buffer.toString();
    }
    
    /**
     * 根据page对象获取对应的分页查询Sql语句，这里只做了两种数据库类型，Mysql和Oracle 其它的数据库都 没有进行分页
     * 
     * @param page 分页对象
     * @param sql 原始sql
     * @return
     * @throws ClassNotFoundException
     */
    private String getPageSql(PageEntity<?> page, String sql)
    {
        StringBuffer sqlBuffer = new StringBuffer(sql);
        if ("mysql".equalsIgnoreCase(databaseType))
        {
            return getMysqlPageSql(page, sqlBuffer);
        }
        else if ("oracle".equalsIgnoreCase(databaseType))
        {
            return getOraclePageSql(page, sqlBuffer);
        }
        return sqlBuffer.toString();
    }
    
    /**
     * 获取Mysql数据库的分页查询语句
     * 
     * @param page 分页对象
     * @param sqlBuffer 包含原sql语句的StringBuffer对象
     * @return Mysql数据库分页语句
     */
    private String getMysqlPageSql(PageEntity<?> page, StringBuffer sqlBuffer)
    {
        // 计算第一条记录的位置，Mysql中记录的位置是从0开始的。
        int offset = (page.getPageNo() - 1) * page.getPageSize();
        sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
        return sqlBuffer.toString();
    }
    
    /**
     * 获取Oracle数据库的分页查询语句
     * 
     * @param page 分页对象
     * @param sqlBuffer 包含原sql语句的StringBuffer对象
     * @return Oracle数据库的分页查询语句
     */
    private String getOraclePageSql(PageEntity<?> page, StringBuffer sqlBuffer)
    {
        // 计算第一条记录的位置，Oracle分页是通过rownum进行的，而rownum是从1开始的
        int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
        sqlBuffer.insert(0, "select u.*, rownum r from (")
            .append(") u where rownum < ")
            .append(offset + page.getPageSize());
        sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
        // 上面的Sql语句拼接之后大概是这个样子：
        // select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
        return sqlBuffer.toString();
    }
    
    /**
     * 给当前的参数对象page设置总记录数
     * 
     * @param authoritionSql 权限sql对象
     * @param page Mapper映射语句对应的参数对象
     * @param mappedStatement Mapper映射语句
     * @param connection 当前的数据库连接
     * @throws SQLException
     */
    private void setTotalRecord(DataAuthoritionSql authoritionSql, PageEntity<?> page, MappedStatement mappedStatement,
        Connection connection)
            throws SQLException
    {
        // 获取对应的BoundSql，这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
        // delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        
        // 通过查询Sql语句获取到对应的计算总记录数的sql语句
        String countSql = this.getCountSql(authoritionSql.getSql());
        
        // 利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
            authoritionSql.isNull() ? null : boundSql.getParameterMappings(), authoritionSql.isNull() ? null : page);
            
        // 通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);
        
        // 通过connection建立一个countSql对应的PreparedStatement对象。
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try
        {
            pstmt = connection.prepareStatement(countSql);
            // 通过parameterHandler给PreparedStatement对象设置参数
            parameterHandler.setParameters(pstmt);
            // 之后就是执行获取总记录数的Sql语句和获取结果了。
            rs = pstmt.executeQuery();
            if (rs.next())
            {
                int totalRecord = rs.getInt(1);
                // 给当前的参数page对象设置总记录数
                page.setTotalRecord(totalRecord);
            }
        }
        catch (SQLException e)
        {
            e.printStackTrace();
            throw e;
        }
        finally
        {
            try
            {
                if (rs != null)
                    rs.close();
                if (pstmt != null)
                    pstmt.close();
            }
            catch (SQLException e)
            {
                e.printStackTrace();
            }
        }
    }
    
    /**
     * 根据原Sql语句获取对应的查询总记录数的Sql语句
     * 
     * @param sql
     * @return
     */
    private String getCountSql(String sql)
    {
        return "SELECT count(0) FROM (" + sql + ") COUNT_ALIAS";
    }
    
    private static class ReflectUtil
    {
        /**
         * 利用反射获取指定对象的指定属性
         * 
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @return 目标属性的值
         */
        public static Object getFieldValue(Object obj, String fieldName)
        {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null)
            {
                field.setAccessible(true);
                try
                {
                    result = field.get(obj);
                }
                catch (IllegalArgumentException e)
                {
                    e.printStackTrace();
                }
                catch (IllegalAccessException e)
                {
                    e.printStackTrace();
                }
            }
            return result;
        }
        
        /**
         * 利用反射获取指定对象里面的指定属性
         * 
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @return 目标字段
         */
        private static Field getField(Object obj, String fieldName)
        {
            Field field = null;
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass())
            {
                try
                {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                }
                catch (NoSuchFieldException e)
                {
                    // 这里不用做处理，子类没有该字段可能对应的父类有，都没有就返回null。
                }
            }
            return field;
        }
        
        /**
         * 利用反射设置指定对象的指定属性为指定的值
         * 
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @param fieldValue 目标值
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue)
        {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null)
            {
                try
                {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                }
                catch (IllegalArgumentException e)
                {
                    e.printStackTrace();
                }
                catch (IllegalAccessException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }
    
}
